import json
import os
import re

# Token alphabet of the AMN format: a 6-bit value v is written as
# MAIN_TOKENS[v], so the exact order of this list is part of the encoding.
# NOTE(review): "dr" appears twice (index 6 on the first row and index 42 on
# the sixth row). A 6-bit value of 42 is therefore encoded as "dr" but decoded
# back as 6 (list.index returns the first match), corrupting that value.
# Confirm the intended token for index 42 against the legacy format.
MAIN_TOKENS = [
    "op","to","fo","de","cn","es","dr","cv",
    "dd","gg","ss","uu","ad","nm","mu","ms",
    "bq","qa","dp","da","ma","cc","he","re",
    "vx","ee","se","rr","cy","un","vt","ar",
    "bo","ft","df","vh","th","pp","ra","bu",
    "fa","rf","dr","fg","xv","ho","pe","dq",
    "go","gc","cu","si","qq","be","ke","cs",
    "vs","rs","sr","ct","gt","kh","gn","hp"
]
# 16 tokens (MAIN_TOKENS indices 48-63) used for the 4-bit first slot of a line.
POS_0_TOKENS = MAIN_TOKENS[48:]
# 4 tokens used for the trailing 2-bit slot (added when the payload length is a
# multiple of 3). Their MAIN_TOKENS indices are 2, 18, 34 and 50, whose top two
# bits (as 6-bit values) equal the padding values 0..3.
PADDING_TOKENS = ["fo", "dp", "df", "cu"]

# Regex patterns removed from each encoded line before decoding (decode_mode
# applies them in this order with re.IGNORECASE): the 12-character "it..." and
# "cc..." line prefixes, the "zzad" header prefix, and every "ii" pair anywhere
# in the line (purpose of "ii" not visible here — presumably legacy filler).
TO_DISCARD = [r'^it[a-z]{10}', r'^cc[a-z]{10}', r'^zzad', r'ii']

# ==============================================
# UTILS
# ==============================================

def calculate_encoding_plan(num_bytes):
    """Describe how a payload of ``num_bytes`` bytes is split into tokens.

    Returns ``(num_tokens, allocations)`` where ``allocations`` is a list of
    ``(token_group, bit_width)`` pairs in output order:

    * slot 0 takes 4 bits from ``POS_0_TOKENS``;
    * every later regular slot takes up to 6 bits from ``MAIN_TOKENS``
      (the final regular slot gets whatever bits remain, at most 6);
    * when ``num_bytes`` is a multiple of 3, a final 2-bit
      ``PADDING_TOKENS`` slot is appended.

    An empty payload is a single zero-width ``POS_0_TOKENS`` slot, which the
    encoders render as "cu".
    """
    if num_bytes == 0:
        return 1, [(POS_0_TOKENS, 0)]  # a lone "cu" token

    padded = num_bytes % 3 == 0
    # One token per byte, one extra per 3 bytes, plus one more unless the
    # trailing padding slot takes its place.
    token_budget = num_bytes + num_bytes // 3 + (0 if padded else 1)
    bits_left = num_bytes * 8

    plan = []
    for slot in range(token_budget):
        if slot == 0:
            table, width = POS_0_TOKENS, 4
        else:
            table, width = MAIN_TOKENS, min(6, bits_left)
        plan.append((table, width))
        bits_left -= width

    if padded:
        plan.append((PADDING_TOKENS, 2))

    return len(plan), plan

def adjust_to_legacy(tokens):
    """Rewrite a token list in place to the spelling used by the legacy encoder.

    * A first token "si" (POS_0_TOKENS value 3) is written as "de"
      (MAIN_TOKENS index 3); decode_tokens maps it back through its
      MAIN_TOKENS fallback.
    * A final *padding* token "dp" (PADDING_TOKENS value 1) is written as
      "bq". Read as 6-bit MAIN_TOKENS indices, "dp" (18) and "bq" (16) share
      the top bits ``01``, so the decoded padding value is unchanged.

    Args:
        tokens: list of tokens produced by encode_bytes; modified in place.

    Returns:
        The same list object, for chaining.
    """
    if not tokens:
        return tokens
    if tokens[0] == "si":
        tokens[0] = "de"
    # Only the padding slot may be rewritten. A token count of 4k + 1 (k >= 1)
    # occurs exactly when the payload length is a non-zero multiple of 3, i.e.
    # when the line ends with a padding token. Anywhere else "dp" is a full
    # 6-bit value (18), and writing "bq" (16) would corrupt the last byte.
    if len(tokens) > 1 and len(tokens) % 4 == 1 and tokens[-1] == "dp":
        tokens[-1] = "bq"
    return tokens

def bytes_to_big_endian(byte_array):
    """Pack a sequence of byte values into one integer, first byte most significant.

    For values in range(256) this equals ``int.from_bytes(bytes(byte_array), "big")``;
    an empty sequence yields 0.
    """
    accumulator = 0
    for octet in byte_array:
        # Shift what we have up by one byte and drop the next one into the low 8 bits.
        accumulator = (accumulator << 8) | octet
    return accumulator

def big_endian_to_tokens(big_endian_value, bit_allocations, num_bytes):
    """
    Convert a big-endian integer to tokens using a bit-allocation plan.

    Bits are consumed from the most significant end, one allocation at a time.

    Args:
        big_endian_value: the payload as a single integer (see bytes_to_big_endian).
        bit_allocations: ``(token_group, num_bits)`` pairs, as returned by
            calculate_encoding_plan.
        num_bytes: payload length in bytes; fixes the position of the MSB.

    Returns:
        A list with one token per allocation.
    """
    tokens = []  # Resulting tokens
    bit_position = num_bytes * 8  # one past the highest bit still to be read
    last_position = len(bit_allocations) - 1

    for position, (group, num_bits) in enumerate(bit_allocations):
        if num_bits == 0:
            # Special case for 0 bytes: the plan is one zero-width slot -> "cu".
            tokens.append(group[2])
            continue

        bit_position -= num_bits
        mask = (1 << num_bits) - 1  # "num_bits" ones, e.g. num_bits = 4 => 0b1111
        index = (big_endian_value >> bit_position) & mask

        # A 4-bit MAIN_TOKENS slot in the final position is spread over the
        # 64-entry table with a step of 4. Use the real position from
        # enumerate: list.index() would return the FIRST equal allocation and
        # rescan the whole plan for every token.
        if position == last_position and group == MAIN_TOKENS and num_bits == 4:
            index *= 4

        tokens.append(group[index % len(group)])

    return tokens

# ==============================================
# ENCODE / DECODE
# ==============================================

def encode_bytes(byte_array):
    """Encode a byte sequence as a list of 2-character tokens.

    The payload is read as one big-endian integer and cut into the bit-width
    slots given by calculate_encoding_plan. Slots are filled starting from the
    least significant bits (the plan is walked backwards, as the legacy encoder
    does), and the tokens are returned in plan order.

    Args:
        byte_array: bytes (or a sequence of ints 0-255) to encode.

    Returns:
        The list of tokens; ``["cu"]`` for an empty payload.
    """
    total_bits = 8 * len(byte_array)
    if total_bits == 0:
        return ["cu"]

    plan = calculate_encoding_plan(len(byte_array))[1]
    payload = bytes_to_big_endian(byte_array)

    backwards = []
    shift = 0  # number of low-order bits already consumed
    for table, width in plan[::-1]:
        if width == 0:
            backwards.append(table[2])  # zero-width slot: "cu"
            continue

        value = (payload >> shift) & ((1 << width) - 1)
        shift += width

        # 4-bit MAIN_TOKENS slots are spread over the table with a step of 4.
        if width == 4 and table == MAIN_TOKENS:
            value *= 4

        backwards.append(table[value % len(table)])

    return backwards[::-1]

def decode_tokens(tokens, num_tokens):
    """Decode the tokens of one line back into a list of byte values.

    Args:
        tokens: the 2-character tokens of one line (line prefix already removed).
        num_tokens: the token count, normally ``len(tokens)``.

    Returns:
        A list of ints in range(256). A single token (the empty-payload marker
        produced by ``encode_bytes(b"")``) decodes to ``[]``.

    Raises:
        ValueError: if a token is neither in its slot's group nor in
            MAIN_TOKENS, or if its value does not fit in the slot's bit width.

    Note:
        calculate_encoding_plan is deliberately called with the TOKEN count, not
        the byte count. For that many "bytes", the first ``num_tokens`` slots
        are one 4-bit POS_0_TOKENS slot followed only by 6-bit MAIN_TOKENS
        slots, so every later token is read as a full 6-bit MAIN_TOKENS index:

        * a trailing 4-bit token was written as ``value * 4`` (see encode_bytes),
          so its 6-bit index is the 4 data bits followed by ``00``;
        * a trailing padding token ("fo", "dp"/"bq", "df", "cu") has a
          MAIN_TOKENS index whose top two bits equal its 2-bit padding value.

        In both cases the extra low bits form an incomplete final byte, which
        is dropped.
    """
    if num_tokens == 1:
        return []

    _, bit_allocations = calculate_encoding_plan(num_tokens)

    bit_parts = []
    for token, (group, n_bits) in zip(tokens, bit_allocations):
        if token in group:
            index = group.index(token)
        elif token in MAIN_TOKENS:
            # Legacy aliases, e.g. "de" (MAIN index 3) in place of "si"
            # (POS_0 index 3) in the first position.
            index = MAIN_TOKENS.index(token)
        else:
            raise ValueError(f"unknown token {token!r}")

        # A 4-bit MAIN_TOKENS slot is written with a step of 4.
        if group == MAIN_TOKENS and n_bits == 4:
            index //= 4

        if index >= 1 << n_bits:
            # Formatting it would emit more than n_bits bits and silently
            # shift every following bit of the payload.
            raise ValueError(
                f"token {token!r} (value {index}) does not fit in {n_bits} bits"
            )

        bit_parts.append(format(index, f"0{n_bits}b"))

    bits = "".join(bit_parts)
    # Keep only complete bytes; trailing leftover bits are padding (see Note).
    return [int(bits[i:i + 8], 2) for i in range(0, len(bits) - 7, 8)]

# ==============================================
# FILE-LEVEL MODES AND CLI
# ==============================================

def encode_mode(filename):
    """Encode a JSON file into the line-oriented ``.amn`` token format.

    The input JSON object may contain:

    * ``header``   -- a string, written as one line prefixed with ``zzad``;
    * ``immobili`` -- a list of JSON values; each is serialized with
      ``json.dumps``, suffixed with ``#|#100#|#100#|#`` and written on a line
      prefixed with ``itopopopopop``;
    * ``footers``  -- a list of strings, each written on a line prefixed with
      ``ccopopopopop``.

    Missing keys default to an empty header / empty lists. Every line is
    newline-terminated. The output is written next to the input, with the
    extension replaced by ``.amn``.

    Args:
        filename: path of the JSON file to encode.
    """
    # Read as UTF-8 explicitly so the result does not depend on the platform's
    # locale encoding (e.g. cp1252 on Windows).
    with open(filename, 'r', encoding='utf-8') as f:
        data = json.load(f)

    header = data.get('header', '')
    immobili = data.get('immobili', [])
    footers = data.get('footers', [])
    # splitext strips only a real extension; rsplit('.') would mangle paths
    # such as "../input" or "dir.v2/file".
    output_filename = os.path.splitext(filename)[0] + '.amn'

    def encode_line(prefix, text):
        # Tokenize the UTF-8 bytes of ``text`` and place them after the
        # line-type prefix that decode_mode uses to classify the line.
        tokens = adjust_to_legacy(encode_bytes(text.encode('utf-8')))
        return prefix + ''.join(tokens)

    lines = [encode_line('zzad', header)]
    lines.extend(
        encode_line('itopopopopop', f"{json.dumps(immobile)}#|#100#|#100#|#")
        for immobile in immobili
    )
    lines.extend(encode_line('ccopopopopop', footer) for footer in footers)

    with open(output_filename, 'w', encoding='utf-8') as f:
        f.write(''.join(line + '\n' for line in lines))

    print(f"File saved to {output_filename}")

def decode_mode(filename):
    """Decode an ``.amn`` token file back into a JSON file.

    Each non-blank line is classified by its prefix:

    * ``zzad`` -> ``header`` (a string; a later header line replaces an
      earlier one);
    * ``it``   -> appended to the ``immobili`` list. It is parsed with
      ``json.loads`` after the trailing ``#|#100#|#100#|#`` marker written by
      encode_mode is removed;
    * ``cc``   -> appended to the ``footers`` list.

    The TO_DISCARD patterns strip the prefixes. The rest of the line is split
    into 2-character tokens and decoded with decode_tokens. The decoded bytes
    are read as UTF-8, the encoding encode_mode uses. Bytes that are not valid
    UTF-8 fall back to one character per byte (Latin-1), which is how every
    line used to be decoded.

    The result is written next to the input, with the extension replaced by
    ``.json`` (UTF-8, indent=4, non-ASCII characters kept as-is).

    Args:
        filename: path of the ``.amn`` file to decode.

    Raises:
        ValueError: if a line has an unrecognized prefix or an undecodable token.
        json.JSONDecodeError: if an immobile line is not valid JSON.
    """
    immobile_marker = '#|#100#|#100#|#'
    # Compile the cleanup patterns once rather than once per line.
    discard_patterns = [re.compile(p, re.IGNORECASE) for p in TO_DISCARD]
    # splitext strips only a real extension; rsplit('.') would mangle paths
    # such as "../input" or "dir.v2/file".
    output_filename = os.path.splitext(filename)[0] + '.json'

    with open(filename, 'r', encoding='utf-8') as f:
        encoded_lines = f.read().strip().splitlines()

    decoded_json = {}

    for line_no, encoded_line in enumerate(encoded_lines, start=1):
        if not encoded_line.strip():
            # Blank lines hold no record. Previously they were decoded under
            # the previous line's key, or raised NameError on the first line.
            continue

        # Determine the key based on prefix
        if encoded_line.startswith('zzad'):
            key = 'header'
        elif encoded_line.startswith('it'):
            key = 'immobili'
        elif encoded_line.startswith('cc'):
            key = 'footers'
        else:
            raise ValueError(
                f"{filename}:{line_no}: unrecognized line prefix {encoded_line[:4]!r}"
            )

        for pattern in discard_patterns:
            encoded_line = pattern.sub('', encoded_line)

        tokens = [encoded_line[i:i + 2] for i in range(0, len(encoded_line), 2)]
        raw = bytes(decode_tokens(tokens, len(tokens)))
        try:
            decoded_string = raw.decode('utf-8')
        except UnicodeDecodeError:
            # Not UTF-8 (e.g. produced by another tool): keep the previous
            # one-code-point-per-byte behaviour.
            decoded_string = raw.decode('latin-1')

        if key == 'header':
            decoded_json[key] = decoded_string
        elif key == 'immobili':
            # Strip only the trailing marker, so the same text inside the JSON
            # payload is left untouched.
            if decoded_string.endswith(immobile_marker):
                decoded_string = decoded_string[:-len(immobile_marker)]
            decoded_json.setdefault(key, []).append(json.loads(decoded_string))
        else:
            decoded_json.setdefault(key, []).append(decoded_string)

    # ensure_ascii=False writes non-ASCII text, so the file encoding must be explicit.
    with open(output_filename, 'w', encoding='utf-8') as f:
        f.write(json.dumps(decoded_json, indent=4, ensure_ascii=False))

    print(f"File saved to {output_filename}")

# --- CLI setup ---
def main():
    """Command-line entry point: ``<mode> <file>`` where mode is encode or decode."""
    import argparse

    cli = argparse.ArgumentParser(description="Encode or decode AMN files")
    cli.add_argument('mode', choices=['encode', 'decode'], help="Mode of operation")
    cli.add_argument('file', type=str, help="Input filename (JSON for encode, .amn for decode)")
    opts = cli.parse_args()

    # argparse restricts mode to the two choices, so anything but "encode" is "decode".
    handler = encode_mode if opts.mode == 'encode' else decode_mode
    handler(opts.file)


if __name__ == "__main__":
    main()